import os    # nopep8
import sys   # nopep8
sys.path.append(os.path.join(os.path.dirname(__file__), 'vigs'))   # nopep8
import time
import torch
import cv2
import re
import os
import argparse
import numpy as np
import lietorch
import resource
# Raise the soft limit on open file descriptors for this process. Presumably
# needed because torch.multiprocessing shares tensors through file
# descriptors — confirm.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_nofile_soft = 100000
if rlimit[1] != resource.RLIM_INFINITY:
    # An unprivileged process cannot set its soft limit above the hard limit
    # (setrlimit raises ValueError), so clamp to the hard limit.
    _nofile_soft = min(_nofile_soft, rlimit[1])
try:
    resource.setrlimit(resource.RLIMIT_NOFILE, (_nofile_soft, rlimit[1]))
except (ValueError, OSError) as _err:
    # Best effort: some platforms cap NOFILE below the hard limit; keep the
    # current limit rather than refusing to start.
    print(f"warning: could not set RLIMIT_NOFILE soft limit to {_nofile_soft}: {_err}", file=sys.stderr)

from tqdm import tqdm
from torch.multiprocessing import Process, Queue
from vigs import VIGS

def get_tstamps_full(imagedir, start, length, stride):
    """Return the timestamps of the frames selected from ``imagedir``.

    Filenames are ordered exactly like ``mono_stream`` orders them, so that
    row ``i`` of the result is the timestamp of the frame streamed with index
    ``t == i``:
    - numerically by ``float(name[:-4])`` when every name parses as a float;
    - lexicographically otherwise.
    The same ``[start:start+length][::stride]`` selection is then applied.

    The timestamp of a file is the last number found in its name. It is used
    as-is when ``imagedir`` contains "stuttgart" (case-insensitive). Otherwise
    it is divided by 1e9 (assumed to be nanoseconds).

    Args:
        imagedir: directory containing only the image files.
        start: index of the first selected frame.
        length: maximum number of frames before striding.
        stride: keep every ``stride``-th frame.

    Returns:
        float64 array of shape (K, 1), K = number of selected frames.
    """
    names = os.listdir(imagedir)
    try:
        # NOTE: must be float, since some filenames are in s, some in ns
        names = sorted(names, key=lambda x: float(os.path.basename(x)[:-4]))
    except ValueError:
        # Not purely numerical timestamps in the filenames.
        names = sorted(names)
    selected = names[start:start + length][::stride]

    number_re = re.compile(r"[+]?(?:\d*\.\d+|\d+)")
    tstamps_full = np.array(
        [float(number_re.findall(name)[-1]) for name in selected],
        dtype=np.float64,
    )[..., np.newaxis]
    if 'stuttgart' not in imagedir.lower():
        tstamps_full = tstamps_full / 1e9
    return tstamps_full

def show_image(image, depth_prior, depth, normal):
    """Display a debug mosaic of the current frame in an OpenCV window.

    Panels, left to right:
    - the frame;
    - the normal map;
    - ``depth_prior`` and ``depth`` side by side, colorized over [0, 4].
    The mosaic is subsampled by 2 along both axes before display.

    Args:
        image: 3xHxW tensor; channels are reversed (presumably RGB -> BGR for
            OpenCV) and values divided by 255 for display.
        depth_prior: depth tensor placed left of ``depth`` (assumed HxW —
            TODO confirm).
        depth: depth tensor placed right of ``depth_prior``.
        normal: 3xHxW tensor; channels reversed and mapped from [-1, 1] to
            [0, 1] for display.
    """
    from util.utils import colorize_np

    frame_hwc = image[[2, 1, 0]].permute(1, 2, 0).cpu().numpy()

    prior_np = depth_prior.cpu().numpy()
    pred_np = depth.cpu().numpy()
    depth_panel = colorize_np(np.concatenate((prior_np, pred_np), axis=1), range=(0, 4))

    normal_hwc = normal.permute(1, 2, 0).cpu().numpy()
    normal_panel = (normal_hwc[..., [2, 1, 0]] + 1.) / 2.

    panels = (frame_hwc / 255.0, normal_panel, depth_panel)
    mosaic = np.concatenate(panels, axis=1)
    cv2.imshow('rgb / prior normal / aligned prior depth / JDSA depth', mosaic[::2, ::2])
    cv2.waitKey(1)

def mono_stream(queue, imagedir, calib, undistort=False, cropborder=False, start=0, length=100000, stride=1):
    """Producer: read, preprocess and enqueue the selected images of ``imagedir``.

    For every selected frame this puts
    ``(t, timestamp, image[None], intrinsics[None], is_last)`` on ``queue``:
    - ``t``: index within the selected (start/length/stride) subsequence.
    - ``timestamp``: the last number in the filename. Used as-is when
      ``imagedir`` contains "stuttgart"; otherwise divided by 1e9 (assumed
      nanoseconds).
    - ``image``: CHW tensor of the RGB frame, resized to about 341*640 pixels
      with both sides floored to a multiple of 8.
    - ``intrinsics``: ``[fx, fy, cx, cy]``, adjusted for the crop and resize.
    - ``is_last``: True only for the final frame.

    Args:
        queue: queue receiving the frame tuples.
        imagedir: directory containing only the image files.
        calib: path to a space-delimited file ``fx fy cx cy [distortion...]``.
        undistort: undistort frames when distortion coefficients are present.
        cropborder: pixels cropped from each image side (0/False disables).
        start, length, stride: select frames ``sorted[start:start+length][::stride]``.

    Raises:
        RuntimeError: if no frame is selected or an image cannot be decoded.
    """
    RES = 341 * 640  # target pixel count (area) of the resized frames
    calib = np.loadtxt(calib, delimiter=" ")
    K = np.array([[calib[0], 0, calib[2]], [0, calib[1], calib[3]], [0, 0, 1]])

    names = os.listdir(imagedir)
    try:
        # NOTE: must be float, since some filenames are in s, some in ns
        names = sorted(names, key=lambda x: float(os.path.basename(x)[:-4]))
    except ValueError:
        # Not purely numerical timestamps in the filenames: lexicographic order.
        names = sorted(names)
    image_list = names[start:start + length][::stride]
    if not image_list:
        # The consumer loops until it receives an is_last frame; fail loudly
        # instead of silently sending nothing.
        raise RuntimeError(
            f"no images selected from {imagedir!r} (start={start}, length={length}, stride={stride})"
        )

    timestamps_in_seconds = 'stuttgart' in imagedir.lower()
    number_re = re.compile(r"[+]?(?:\d*\.\d+|\d+)")

    for t, imfile in enumerate(image_list):
        timestamp = float(number_re.findall(imfile)[-1])
        if not timestamps_in_seconds:
            timestamp /= 1e9

        path = os.path.join(imagedir, imfile)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals unreadable/non-image files by returning None.
            raise RuntimeError(f"could not read image {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Fresh tensor per frame: it is modified in place below.
        intrinsics = torch.tensor(calib[:4])
        if len(calib) > 4 and undistort:
            # The extra calib entries are the distortion coefficients. The
            # camera matrix K is kept, so the intrinsics stay valid.
            image = cv2.undistort(image, K, calib[4:])

        if cropborder > 0:
            image = image[cropborder:-cropborder, cropborder:-cropborder]
            intrinsics[2:] -= cropborder  # principal point shifts with the crop

        # Resize to about RES pixels, keeping the aspect ratio; both sides are
        # floored to a multiple of 8.
        h0, w0, _ = image.shape
        h1 = int(h0 * np.sqrt((RES) / (h0 * w0)))
        w1 = int(w0 * np.sqrt((RES) / (h0 * w0)))
        h1 = h1 - h1 % 8
        w1 = w1 - w1 % 8
        image = cv2.resize(image, (w1, h1))
        image = torch.as_tensor(image).permute(2, 0, 1)
        intrinsics[[0, 2]] *= (w1 / w0)
        intrinsics[[1, 3]] *= (h1 / h0)
        is_last = (t == len(image_list) - 1)
        queue.put((t, timestamp, image[None], intrinsics[None], is_last))

    # NOTE(review): presumably keeps the producer alive until the consumer has
    # received the tensors shared via torch.multiprocessing — confirm before removing.
    time.sleep(20)


def save_trajectory(vigs, traj_full, imagedir, output, start=0, length=100000, stride=1, final=False, tstamps_full=None, suffix=''):
    """Write keyframe (and optionally full-rate) trajectories as text files.

    Each keyframe row is its timestamp followed by the data of the inverted
    stored pose (``lietorch.SE3(poses).inv().data``). The comment on the
    output says it is meant for evo evaluation.

    Args:
        vigs: VIGS instance. Reads ``vigs.video.counter``, ``.tstamp`` and
            ``.poses``, plus ``.intrinsics`` when ``final``.
        traj_full: optional per-frame pose array. When given, it is prefixed
            with the matching rows of ``tstamps_full`` and written to
            ``traj_full{suffix}.txt``.
        imagedir, start, length, stride: passed to ``get_tstamps_full`` when
            ``tstamps_full`` is not supplied.
        output: output directory.
        final: if True, write ``traj_kf{suffix}.txt`` and ``intrinsics.npy``
            in ``output``. Otherwise write a snapshot numbered by the keyframe
            count under ``output/traj``.
        tstamps_full: optional precomputed per-frame timestamps, shape (N, 1).
        suffix: string appended to the output file names.
    """
    video = vigs.video
    n_kf = video.counter.value
    kf_frame_ids = video.tstamp[:n_kf]
    kf_poses = lietorch.SE3(video.poses[:n_kf]).inv().data

    if final:
        # NOTE(review): intrinsics are scaled by 8 before saving — presumably
        # the video buffer stores them at 1/8 resolution; confirm.
        np.save(f"{output}/intrinsics.npy", video.intrinsics[0].cpu().numpy() * 8)

    if tstamps_full is None:
        tstamps_full = get_tstamps_full(imagedir, start, length, stride)

    # video.tstamp holds frame indices; look up their real timestamps.
    kf_times = tstamps_full[kf_frame_ids.cpu().numpy().astype(int)]
    kf_rows = np.concatenate([kf_times, kf_poses.cpu().numpy()], axis=1)

    if final:
        kf_path = f"{output}/traj_kf{suffix}.txt"
    else:
        os.makedirs(f"{output}/traj", exist_ok=True)
        kf_path = f"{output}/traj/traj_kf{suffix}_{n_kf:04d}.txt"
    np.savetxt(kf_path, kf_rows)  # for evo evaluation

    if traj_full is None:
        return
    full_rows = np.concatenate([tstamps_full[:len(traj_full)], traj_full], axis=1)
    np.savetxt(f"{output}/traj_full{suffix}.txt", full_rows)

if __name__ == '__main__':
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--imagedir", type=str, help="path to image directory")
    parser.add_argument("--imufile", type=str, help="path to imu measurement file")
    parser.add_argument("--calib", type=str, help="path to calibration file")
    parser.add_argument("--config", type=str, help="path to configuration file")
    parser.add_argument("--output", default='outputs/demo', help="path to save output")
    parser.add_argument("--gtdepthdir", type=str, default=None, help="optional for evaluation, assumes 16-bit depth scaled by 6553.5")

    parser.add_argument("--stride", default=1, type=int, help="frame stride")

    parser.add_argument("--weights", default=os.path.join(os.path.dirname(__file__), "pretrained_models/droid.pth"))
    parser.add_argument("--buffer", type=int, default=-1, help="number of keyframes to buffer (default: 1500)")
    parser.add_argument("--undistort", action="store_true", help="undistort images if calib file contains distortion parameters")
    parser.add_argument("--cropborder", type=int, default=0, help="crop images to remove black border")

    parser.add_argument("--droidvis", action="store_true")
    parser.add_argument("--rerunvis", action="store_true")
    parser.add_argument("--gsvis", action="store_true")
    parser.add_argument("--gsmapping", action="store_true")
    parser.add_argument("--final_ba_inertial", action="store_true")

    parser.add_argument("--start", type=int, default=0, help="start frame")
    parser.add_argument("--length", type=int, default=100000, help="number of frames to process")
    parser.add_argument("--IMU_poseinit_after", type=int, default=100000, help="enable IMU init after this keyframe index")
    args = parser.parse_args()

    if torch.cuda.is_available():
        print("GPU Available:", torch.cuda.get_device_name(0))

    os.makedirs(args.output, exist_ok=True)

    # IMU measurements: try comma-separated first, then space-separated.
    args.imus = None
    if args.imufile is not None:
        try:
            args.imus = np.loadtxt(args.imufile, delimiter=',')
        except ValueError:
            args.imus = np.loadtxt(args.imufile, delimiter=' ')

    # Keep a copy of the config next to the results. shutil avoids passing
    # user-supplied paths through a shell (spaces, metacharacters).
    if args.config is not None:
        try:
            shutil.copy(args.config, os.path.join(args.output, "config.yaml"))
        except shutil.SameFileError:
            pass  # config already lives at the destination; nothing to copy

    torch.multiprocessing.set_start_method('spawn')

    vigs = None
    queue = Queue(maxsize=8)

    reader = Process(target=mono_stream, args=(queue, args.imagedir, args.calib, args.undistort, args.cropborder, args.start, args.length, args.stride))
    reader.start()

    # --buffer < 0 (the default) selects 1500 keyframes. This is a brute-force
    # size because some sequences overflow smaller buffers; some Aria
    # sequences need ~3000, which can now be passed explicitly.
    if args.buffer < 0:
        args.buffer = 1500

    # Only the frames selected by start/length/stride are streamed.
    n_frames = len(range(len(os.listdir(args.imagedir)))[args.start:args.start + args.length][::args.stride])
    pbar = tqdm(total=n_frames, desc="Processing keyframes")
    tstamps_full = get_tstamps_full(args.imagedir, args.start, args.length, args.stride)
    while True:
        (t, timestamp, image, intrinsics, is_last) = queue.get()
        pbar.update()

        if vigs is None:
            # The system is built lazily: the image size is only known once
            # the first frame arrives.
            args.image_size = [image.shape[2], image.shape[3]]
            vigs = VIGS(args)

        vigs.track(t, timestamp, image, intrinsics=intrinsics, is_last=is_last)

        if args.gsmapping:
            pbar.set_description(f"Processing keyframe {vigs.video.counter.value} gs {vigs.gs.gaussians._xyz.shape[0]}")
        else:
            pbar.set_description(f"Processing keyframe {vigs.video.counter.value}")
        if is_last:
            pbar.close()
            break

    reader.join()
    if hasattr(vigs, 'mp_backend'):
        vigs.video.pgobuf.stop()
        vigs.mp_backend.join(timeout=1.0)

    traj_full_beforeBA = vigs.traj_filler(vigs.images)
    save_trajectory(vigs, traj_full_beforeBA.inv().data.cpu().numpy(), args.imagedir, args.output, start=args.start, length=args.length, stride=args.stride, final=True, tstamps_full=tstamps_full, suffix='_beforeBA')
    # NOTE(review): vigs.gs is only read under --gsmapping during tracking.
    # Guard in case it is absent otherwise — confirm VIGS always creates it.
    if getattr(vigs, 'gs', None) is not None:
        vigs.gs.gaussians.save_ply(f'{args.output}/3dgs_before_final.ply')
    traj = vigs.terminate(inertial=args.final_ba_inertial & vigs.video.IMU_initialized)
    if getattr(vigs, 'gs', None) is not None:
        vigs.gs.gaussians.save_ply(f'{args.output}/3dgs_final.ply')
    save_trajectory(vigs, traj, args.imagedir, args.output, start=args.start, length=args.length, stride=args.stride, final=True, tstamps_full=tstamps_full, suffix='_afterBA')
    print("Done")
